// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#include <utility>

#define BCAST_LLKOP EltwiseBinaryType::ELWMUL
#define BCAST_DIM BroadcastType::COL

#include "compute_kernel_api.h"
#include "compute_kernel_api/bcast.h"
#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/layernorm.h"
#include "compute_kernel_api/eltwise_binary_sfpu.h"
#include "compute_kernel_api/tile_move_copy.h"
#include "compute_kernel_api/welford.h"
#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
#include "compute_kernel_api/eltwise_unary/rsqrt.h"
#include "compute_kernel_api/transpose_wh.h"
#include "compute_kernel_api/transpose_wh_dest.h"

namespace NAMESPACE {

// Holds a pair of circular-buffer indices and tracks which one is currently
// the "read" side and which is the "write" side. Calling swap() exchanges the
// two roles, so a buffer that was just written becomes the one to read next.
class cb_ping_pong {
public:
    // in_ping: index initially returned by read().
    // in_pong: index initially returned by write().
    cb_ping_pong(tt::CBIndex in_ping, tt::CBIndex in_pong) : ping(in_ping), pong(in_pong) {}

    // Index currently designated as the read side.
    inline auto read() const { return ping; }

    // Index currently designated as the write side.
    inline auto write() const { return pong; }

    // Exchange the read and write roles.
    void swap() { std::swap(ping, pong); }

private:
    tt::CBIndex ping;  // current read-side index
    tt::CBIndex pong;  // current write-side index
};

// Compute kernel entry point. MAIN (like NAMESPACE) is not defined in this
// file; presumably it is a macro supplied by the included kernel headers.
//
// For each of the NCHt outer iterations this function:
//   1. Makes a first pass over Wt input tiles, blk tiles at a time, feeding
//      each transposed tile to welford_tile and carrying the running
//      mean/variance between blocks through the cb_welford ping-pong buffers.
//      With fuse_pre_add the input to this pass is first computed as
//      cb_in + cb_inb and staged in cb_interm_pre_add.
//   2. Transposes the final mean/variance tiles into cb_ex / cb_ex2.
//   3. Computes rsqrt(cb_ex2 + cb_eps) into cb_ex2pe and column-broadcasts it.
//   4. Makes a second pass over Wt tiles computing (x - E[x]) * cb_ex2pe,
//      optionally multiplied by gamma and/or with beta added, and pushes the
//      result to cb_out.
//
// Both passes wait on and pop Wt tiles from cb_in (and cb_inb when
// fuse_pre_add is set), so the producer of those CBs has to supply each row's
// tiles twice.
//
// Runtime args:
//   0: NCHt - number of outer iterations.
// Compile-time args:
//   0: Wt            - tiles consumed per pass of one outer iteration
//   1: blk           - tiles processed per block
//   2: do_gamma      - 1 enables the gamma multiply
//   3: do_beta       - 1 enables the beta add
//   4: FLOAT32_DTYPE - parsed but not referenced in this function
//   5: W             - passed through to welford_tile
//   6: tile_width    - used to scale the tile index passed to welford_tile
//   8: fuse_pre_add  - non-zero enables x = cb_in + cb_inb
//   (index 7 is not read by this function)
void MAIN {
    uint32_t NCHt = get_arg_val<uint32_t>(0);
    constexpr uint32_t Wt = get_compile_time_arg_val(0);
    constexpr uint32_t blk = get_compile_time_arg_val(1);
    constexpr uint32_t do_gamma = get_compile_time_arg_val(2);
    constexpr uint32_t do_beta = get_compile_time_arg_val(3);
    // NOTE(review): FLOAT32_DTYPE is not used anywhere in this function.
    constexpr bool FLOAT32_DTYPE = get_compile_time_arg_val(4) == 1;
    constexpr uint32_t W = get_compile_time_arg_val(5);
    constexpr uint32_t tile_width = get_compile_time_arg_val(6);
    constexpr bool fuse_pre_add = static_cast<bool>(get_compile_time_arg_val(8));

    // Note that the entire W dimension must fit in the intermed0 CB for this kernel to be correct
    constexpr auto cb_scaler = tt::CBIndex::c_2;  // single tile generated by the reader
    constexpr auto cb_eps = tt::CBIndex::c_3;     // single tile generated by the reader
    constexpr auto cb_in = tt::CBIndex::c_0;      // input x or a for fused pre-add (x=a+b)
    constexpr auto cb_inb = tt::CBIndex::c_1;     // input b for fused pre-add
    constexpr auto cb_out = tt::CBIndex::c_16;    // output
    constexpr auto cb_gamma = tt::CBIndex::c_5;
    constexpr auto cb_beta = tt::CBIndex::c_6;
    // cb_xmm is deliberately non-constexpr: the second pass redirects it to
    // cb_out when neither gamma nor beta is applied, and it is restored to
    // c_24 at the end of every outer iteration.
    uint32_t cb_xmm = tt::CBIndex::c_24;                   // x - E[x]
    constexpr auto cb_ex = tt::CBIndex::c_18;              // E[x]
    constexpr auto cb_ex2 = tt::CBIndex::c_19;             // Var[x] = E[(x-E[x])^2]
    constexpr auto cb_ex2pe = tt::CBIndex::c_21;           // Var[x]+ε
    constexpr auto cb_fusion = tt::CBIndex::c_22;          // stream gamma/beta
    constexpr auto cb_interm_pre_add = tt::CBIndex::c_23;  // intermediate for fused pre-add
    auto cb_welford_ping = tt::CBIndex::c_4;               // Ping-pong buffer for storing Welford's mean/var
    auto cb_welford_pong = tt::CBIndex::c_7;               // Ping-pong buffer for storing Welford's mean/var
    // CB that the first (statistics) pass reads its tiles from.
    constexpr auto cb_result_or_input = fuse_pre_add ? cb_interm_pre_add : cb_in;

    // NOTE(review): scaler0 is not used anywhere in this function.
    constexpr auto scaler0 = 0;
    constexpr uint32_t onetile = 1;
    constexpr uint32_t twotiles = 2;

    // Initialize the hardware based on the first op
    // that will be done
    if constexpr (fuse_pre_add) {
        // Init for x = in + b
        binary_op_init_common(cb_in, cb_inb, cb_interm_pre_add);
    } else {
        // Init for transpose
        constexpr auto first_out_cb = cb_ex;
        unary_op_init_common(cb_in, first_out_cb);
    }

    // Both tiles are waited on once here and never popped in this function.
    // cb_scaler is not read by any later op in this function; cb_eps is used
    // in the Var[x]+ε add below.
    cb_wait_front(cb_scaler, onetile);  // comes from the reader
    cb_wait_front(cb_eps, onetile);     // comes from the reader

    constexpr uint32_t dst0 = 0;  // Input tile for Welford's
    constexpr uint32_t dst1 = 1;  // Mean tile for Welford's
    constexpr uint32_t dst2 = 2;  // Variance tile for Welford's

    auto cb_welford = cb_ping_pong(cb_welford_ping, cb_welford_pong);
    for (uint32_t ncht = 0; ncht < NCHt; ncht++) {
        // =====================================
        // First pass over the input.
        // Calculate E[x] and Var[x]
        // =====================================
        // NOTE(review): every iteration processes exactly blk tiles, so this
        // loop appears to assume Wt is a multiple of blk - confirm with host.
        for (uint32_t wt = 0; wt < Wt; wt += blk) {
            if constexpr (fuse_pre_add) {
                // Fused pre-add
                reconfig_data_format(cb_in, cb_inb);
                add_tiles_init(cb_in, cb_inb);
                tile_regs_acquire();
                for (uint32_t j = 0; j < blk; j++) {
                    // Wait cumulatively so tile j can be consumed as soon as it arrives.
                    cb_wait_front(cb_in, j + 1);
                    cb_wait_front(cb_inb, j + 1);
                    add_tiles(cb_in, cb_inb, j, j, j);
                }
                tile_regs_commit();

                // Only cb_inb is popped here; cb_in is popped together with
                // cb_interm_pre_add after the Welford block below.
                cb_pop_front(cb_inb, blk);

                // Pack to intermediate CB (needed
                // to workaround transpose_wh_dest bug)
                pack_reconfig_data_format(cb_interm_pre_add);
                cb_reserve_back(cb_interm_pre_add, blk);
                tile_regs_wait();
                for (uint32_t j = 0; j < blk; j++) {
                    pack_tile(j, cb_interm_pre_add);
                }
                tile_regs_release();
                cb_push_back(cb_interm_pre_add, blk);
            }

            tile_regs_acquire();
            // The first block of a row has no previously accumulated state to
            // reload, so the copy below is skipped when wt == 0.
            if (wt > 0) {
                // Copy previous accumulated (row tiles)
                // mean and variance to dest regs
                cb_wait_front(cb_welford.read(), twotiles);

                reconfig_data_format_srca(cb_welford.read());
                copy_tile_to_dst_init_short(cb_welford.read());
                copy_tile(cb_welford.read(), 0, dst1);
                copy_tile(cb_welford.read(), 1, dst2);

                cb_pop_front(cb_welford.read(), twotiles);
            }

            // Process block of Welford's
            // Shouldn't need a full init, but there's a bug
            // in short init that causes accuracy to drop
            reconfig_data_format_srca(cb_result_or_input);
            transpose_wh_init(cb_result_or_input, cb_result_or_input);
            welford_init();
            for (uint32_t j = 0; j < blk; j++) {
                cb_wait_front(cb_result_or_input, j + 1);
                // Each tile is transposed into dst0 before being accumulated.
                transpose_wh_tile(cb_result_or_input, j, dst0);
                // Presumably the first two arguments are this tile's starting
                // element offset along the row and the total row length W -
                // TODO confirm against welford.h.
                welford_tile<dst0, dst1, dst2, true, 0>((wt + j) * tile_width, W, 0, {});
            }
            tile_regs_commit();

            // Pop the input or result CB
            if constexpr (fuse_pre_add) {
                cb_pop_front(cb_in, blk);
                cb_pop_front(cb_interm_pre_add, blk);
            } else {
                cb_pop_front(cb_in, blk);
            }

            // Pack dst1 and dst2 into CBs
            // Leave as row tiles
            tile_regs_wait();
            cb_reserve_back(cb_welford.write(), twotiles);
            pack_reconfig_data_format(cb_welford.write());
            // Packs two tiles starting at dst1, i.e. dst1 (mean) and dst2 (variance).
            pack_tile_block(dst1, cb_welford.write(), twotiles);
            tile_regs_release();

            cb_push_back(cb_welford.write(), twotiles);

            // The buffer just written becomes the read side for the next
            // block (or for the final transpose after the loop).
            cb_welford.swap();
        }

        // Transpose mean and variance back to
        // columns and pack back to CBs
        // Because of the swap at the end of the loop, read() now names the
        // buffer holding the final accumulated mean/variance.
        cb_wait_front(cb_welford.read(), twotiles);
        transpose_wh_init_short(cb_welford.read());
        reconfig_data_format_srca(cb_welford.read());
        tile_regs_acquire();
        transpose_wh_tile(cb_welford.read(), 0, dst1);
        transpose_wh_tile(cb_welford.read(), 1, dst2);

        tile_regs_commit();
        tile_regs_wait();

        // NOTE(review): the two tiles waited on above are not popped from
        // cb_welford.read() anywhere in this function - confirm how that
        // buffer is drained before the next outer iteration writes to it.
        cb_reserve_back(cb_ex, onetile);
        cb_reserve_back(cb_ex2, onetile);
        pack_reconfig_data_format(cb_ex);
        pack_tile(dst1, cb_ex);
        pack_reconfig_data_format(cb_ex2);
        pack_tile(dst2, cb_ex2);
        cb_push_back(cb_ex, onetile);
        cb_push_back(cb_ex2, onetile);
        tile_regs_release();

        // =====================================
        // Calculate 1/(√(Var(X) + ε))
        // =====================================
        tile_regs_acquire();
        tile_regs_wait();

        cb_wait_front(cb_ex2, onetile);

        reconfig_data_format(cb_ex2, cb_eps);

        // dst0 = Var[x] + ε, then rsqrt is applied to dst0 in place.
        add_tiles_init(cb_ex2, cb_eps);
        add_tiles(cb_ex2, cb_eps, 0, 0, dst0);

        rsqrt_tile_init();
        rsqrt_tile(dst0);

        tile_regs_commit();

        // NOTE(review): no pack_reconfig_data_format(cb_ex2pe) precedes this
        // pack; the last pack reconfig issued was for cb_ex2. Presumably the
        // two CBs share a data format - confirm.
        cb_reserve_back(cb_ex2pe, onetile);
        pack_tile(dst0, cb_ex2pe);
        cb_push_back(cb_ex2pe, onetile);
        tile_regs_release();

        cb_pop_front(cb_ex2, onetile);
        cb_wait_front(cb_ex2pe, onetile);

        // broadcasts the tile since cb_ex2pe is a column vector that contains the important data
        // The tile is read from cb_ex2pe, popped, and the broadcast result is
        // pushed back into the same CB.
        tile_regs_acquire();
        tile_regs_wait();
        reconfig_data_format_srca(cb_ex2pe);
        unary_bcast_init<BroadcastType::COL>(cb_ex2pe, cb_ex2pe);
        unary_bcast<BroadcastType::COL>(cb_ex2pe, 0, dst0);
        cb_pop_front(cb_ex2pe, onetile);
        tile_regs_commit();
        // NOTE(review): this pack is not preceded by cb_reserve_back(cb_ex2pe);
        // it relies on the slot freed by the pop above - confirm this is intended.
        pack_tile(dst0, cb_ex2pe);
        tile_regs_release();
        cb_push_back(cb_ex2pe, onetile);

        // =====================================
        // Second pass over the input.
        // Computes the final value:
        //    x-E[x]
        //(---------------*𝛄)+ß
        //  √(Var(x)+ε)
        // =====================================
        for (uint32_t wt = 0; wt < Wt; wt += blk) {
            tile_regs_acquire();
            tile_regs_wait();
            // Output space is reserved up front on every path; the branches
            // below reserve cb_out again before they pack into it.
            cb_reserve_back(cb_out, blk);
            cb_wait_front(cb_ex, onetile);
            cb_wait_front(cb_in, blk);
            reconfig_data_format(cb_in, cb_ex);
            sub_bcast_cols_init_short(cb_in, cb_ex);
            // x-E[x]
            // (with fuse_pre_add, cb_in holds only `a`, so this is a-E[x];
            // b is added in the fused branch below)
            for (uint32_t j = 0; j < blk; j++) {
                sub_tiles_bcast_cols(cb_in, cb_ex, j, 0, j);
            }
            cb_pop_front(cb_in, blk);
            reconfig_data_format_srca(cb_in, cb_ex2pe);

            if constexpr (fuse_pre_add) {
                // Fuse in = in + b
                // Adds cb_inb tile j to the value already in dest tile j,
                // giving (a + b) - E[x].
                cb_wait_front(cb_inb, blk);
                reconfig_data_format_srca(cb_ex2pe, cb_inb);
                binary_dest_reuse_tiles_init<ELWADD, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_inb);
                for (uint32_t j = 0; j < blk; j++) {
                    binary_dest_reuse_tiles<ELWADD, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_inb, j, j);
                }
                cb_pop_front(cb_inb, blk);
                reconfig_data_format_srca(cb_inb, cb_ex2pe);
            }

            // Multiply by 1/(√(Var(X) + ε))
            // The same cb_ex2pe tile (index 0) scales every dest tile in the block.
            cb_wait_front(cb_ex2pe, 1);
            binary_dest_reuse_tiles_init<ELWMUL, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_ex2pe);
            for (uint32_t j = 0; j < blk; j++) {
                binary_dest_reuse_tiles<ELWMUL, EltwiseBinaryReuseDestType::DEST_TO_SRCB>(cb_ex2pe, 0, j);
            }
            tile_regs_commit();
            // With neither gamma nor beta enabled the normalized result is
            // final, so it is packed straight to cb_out instead of c_24.
            if constexpr (!(do_gamma == 1 or do_beta == 1)) {
                cb_xmm = cb_out;
            }
            pack_reconfig_data_format(cb_xmm);
            cb_reserve_back(cb_xmm, blk);
            for (uint32_t j = 0; j < blk; j++) {
                pack_tile(j, cb_xmm);
            }
            cb_push_back(cb_xmm, blk);
            tile_regs_release();

            if constexpr (do_gamma == 1) {
                // Multiply by gamma
                tile_regs_acquire();
                tile_regs_wait();
                reconfig_data_format(cb_xmm, cb_gamma);
                // Gamma is the last stage when beta is disabled, so the
                // packer is switched to the output CB's format.
                if constexpr (!do_beta) {
                    pack_reconfig_data_format(cb_out);
                }
                cb_wait_front(cb_gamma, blk);
                cb_wait_front(cb_xmm, blk);
                mul_bcast_rows_init_short(cb_xmm, cb_gamma);
                for (uint32_t j = 0; j < blk; j++) {
                    mul_tiles_bcast_rows(cb_xmm, cb_gamma, j, j, j);
                }
                tile_regs_commit();
                cb_pop_front(cb_gamma, blk);
                cb_pop_front(cb_xmm, blk);
                if constexpr (!do_beta) {
                    // Final result: write to the output CB.
                    cb_reserve_back(cb_out, blk);
                    for (uint32_t j = 0; j < blk; j++) {
                        pack_tile(j, cb_out);
                    }
                    cb_push_back(cb_out, blk);
                } else {
                    // Beta still to be applied: stage the scaled tiles back
                    // in cb_xmm for the beta block below.
                    cb_reserve_back(cb_xmm, blk);
                    for (uint32_t j = 0; j < blk; j++) {
                        pack_tile(j, cb_xmm);
                    }
                    cb_push_back(cb_xmm, blk);
                }

                tile_regs_release();
            }
            if constexpr (do_beta == 1) {
                // Add beta
                // Input is cb_xmm, which holds the gamma-scaled tiles when
                // do_gamma is set and the plain normalized tiles otherwise.
                tile_regs_acquire();
                tile_regs_wait();
                reconfig_data_format(cb_xmm, cb_beta);
                pack_reconfig_data_format(cb_out);
                cb_wait_front(cb_beta, blk);
                cb_wait_front(cb_xmm, blk);
                add_bcast_rows_init_short(cb_xmm, cb_beta);
                for (uint32_t j = 0; j < blk; j++) {
                    add_tiles_bcast_rows(cb_xmm, cb_beta, j, j, j);
                }
                tile_regs_commit();
                cb_pop_front(cb_beta, blk);
                cb_pop_front(cb_xmm, blk);
                cb_reserve_back(cb_out, blk);
                for (uint32_t j = 0; j < blk; j++) {
                    pack_tile(j, cb_out);
                }
                tile_regs_release();
                cb_push_back(cb_out, blk);
            }
        }

        // Restore cb_xmm in case the second pass redirected it to cb_out, then
        // release the per-row statistics tiles that were held for the whole
        // second pass.
        cb_xmm = tt::CBIndex::c_24;  // x minus mean
        cb_pop_front(cb_ex2pe, onetile);
        cb_pop_front(cb_ex, onetile);
    }  // NCHt loop
}
}  // namespace NAMESPACE
